{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+-----------+---+------------------+----------+------+------+-----------------+----+\n",
      "|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|\n",
      "+-----------+-----------+---+------------------+----------+------+------+-----------------+----+\n",
      "|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|\n",
      "|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|\n",
      "|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|\n",
      "|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|\n",
      "|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|\n",
      "|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|\n",
      "|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|\n",
      "|    Fantasy|   Carnival| 23|            70.367|     20.56|  8.55| 10.22|            34.23| 9.2|\n",
      "|Fascination|   Carnival| 19|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|\n",
      "|    Freedom|   Carnival|  6|110.23899999999999|      37.0|  9.51| 14.87|            29.79|11.5|\n",
      "|      Glory|   Carnival| 10|             110.0|     29.74|  9.51| 14.87|            36.99|11.6|\n",
      "|    Holiday|   Carnival| 28|            46.052|     14.52|  7.27|  7.26|            31.72| 6.6|\n",
      "|Imagination|   Carnival| 18|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|\n",
      "|Inspiration|   Carnival| 17|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|\n",
      "|     Legend|   Carnival| 11|              86.0|     21.24|  9.63| 10.62|            40.49| 9.3|\n",
      "|   Liberty*|   Carnival|  8|             110.0|     29.74|  9.51| 14.87|            36.99|11.6|\n",
      "|    Miracle|   Carnival|  9|              88.5|     21.24|  9.63| 10.62|            41.67|10.3|\n",
      "|   Paradise|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|\n",
      "|      Pride|   Carnival| 12|              88.5|     21.24|  9.63| 11.62|            41.67| 9.3|\n",
      "|  Sensation|   Carnival| 20|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|\n",
      "+-----------+-----------+---+------------------+----------+------+------+-----------------+----+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "from pyspark.ml.regression import LinearRegression\n",
    "\n",
    "# Start (or reuse) the Spark session used by the rest of the notebook.\n",
    "spark = (\n",
    "    SparkSession.builder\n",
    "    .appName('CruiseLinear')\n",
    "    .getOrCreate()\n",
    ")\n",
    "\n",
    "# header=True: the first CSV row supplies the column names.\n",
    "# inferSchema=True: Spark infers each column's type from the data.\n",
    "df = spark.read.csv(\n",
    "    'cruise_ship_info.csv',\n",
    "    header=True,\n",
    "    inferSchema=True,\n",
    ")\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Ship_name: string (nullable = true)\n",
      " |-- Cruise_line: string (nullable = true)\n",
      " |-- Age: integer (nullable = true)\n",
      " |-- Tonnage: double (nullable = true)\n",
      " |-- passengers: double (nullable = true)\n",
      " |-- length: double (nullable = true)\n",
      " |-- cabins: double (nullable = true)\n",
      " |-- passenger_density: double (nullable = true)\n",
      " |-- crew: double (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55)\n",
      "\n",
      "\n",
      "Row(Ship_name='Quest', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55)\n",
      "\n",
      "\n",
      "Row(Ship_name='Celebration', Cruise_line='Carnival', Age=26, Tonnage=47.262, passengers=14.86, length=7.22, cabins=7.43, passenger_density=31.8, crew=6.7)\n",
      "\n",
      "\n",
      "Row(Ship_name='Conquest', Cruise_line='Carnival', Age=11, Tonnage=110.0, passengers=29.74, length=9.53, cabins=14.88, passenger_density=36.99, crew=19.1)\n",
      "\n",
      "\n",
      "Row(Ship_name='Destiny', Cruise_line='Carnival', Age=17, Tonnage=101.353, passengers=26.42, length=8.92, cabins=13.21, passenger_density=38.36, crew=10.0)\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print the first five rows, each followed by two blank lines.\n",
    "for ship in df.take(5):\n",
    "    print(ship, end='\\n\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------+-----+\n",
      "|      Cruise_line|count|\n",
      "+-----------------+-----+\n",
      "|            Costa|   11|\n",
      "|              P&O|    6|\n",
      "|           Cunard|    3|\n",
      "|Regent_Seven_Seas|    5|\n",
      "|              MSC|    8|\n",
      "|         Carnival|   22|\n",
      "|          Crystal|    2|\n",
      "|           Orient|    1|\n",
      "|         Princess|   17|\n",
      "|        Silversea|    4|\n",
      "|         Seabourn|    3|\n",
      "| Holland_American|   14|\n",
      "|         Windstar|    3|\n",
      "|           Disney|    2|\n",
      "|        Norwegian|   13|\n",
      "|          Oceania|    3|\n",
      "|          Azamara|    2|\n",
      "|        Celebrity|   10|\n",
      "|             Star|    6|\n",
      "|  Royal_Caribbean|   23|\n",
      "+-----------------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('Cruise_line').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, Cruise_category=16.0)]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "\n",
    "# Map each 'Cruise_line' string to a numeric index stored in 'Cruise_category'.\n",
    "indexer = StringIndexer(inputCol='Cruise_line', outputCol='Cruise_category')\n",
    "indexer_model = indexer.fit(df)\n",
    "indexed = indexer_model.transform(df)\n",
    "indexed.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Ship_name',\n",
       " 'Cruise_line',\n",
       " 'Age',\n",
       " 'Tonnage',\n",
       " 'passengers',\n",
       " 'length',\n",
       " 'cabins',\n",
       " 'passenger_density',\n",
       " 'crew',\n",
       " 'Cruise_category']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.linalg import Vectors\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "\n",
    "indexed.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+--------------------+\n",
      "|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_category|            features|\n",
      "+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+--------------------+\n",
      "|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|[6.0,30.276999999...|\n",
      "|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|[6.0,30.276999999...|\n",
      "|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|            1.0|[26.0,47.262,14.8...|\n",
      "+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+--------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor columns packed into the 'features' vector; 'crew' is excluded because it is the label.\n",
    "# NOTE(review): 'Cruise_category' is a StringIndexer index fed in as a plain number -\n",
    "# consider one-hot encoding it before fitting a linear model.\n",
    "feature_cols = [\n",
    "    'Age',\n",
    "    'Tonnage',\n",
    "    'passengers',\n",
    "    'length',\n",
    "    'cabins',\n",
    "    'passenger_density',\n",
    "    'Cruise_category',\n",
    "]\n",
    "assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')\n",
    "output = assembler.transform(indexed)\n",
    "output.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, Cruise_category=16.0, features=DenseVector([6.0, 30.277, 6.94, 5.94, 3.55, 42.64, 16.0]))]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+----+\n",
      "|            features|crew|\n",
      "+--------------------+----+\n",
      "|[6.0,30.276999999...|3.55|\n",
      "|[6.0,30.276999999...|3.55|\n",
      "|[26.0,47.262,14.8...| 6.7|\n",
      "+--------------------+----+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Preview only the assembled feature vector next to the 'crew' column.\n",
    "output.select('features', 'crew').show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+----+\n",
      "|            features|crew|\n",
      "+--------------------+----+\n",
      "|[6.0,30.276999999...|3.55|\n",
      "|[6.0,30.276999999...|3.55|\n",
      "|[26.0,47.262,14.8...| 6.7|\n",
      "+--------------------+----+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Keep just the assembled feature vector and the 'crew' column for modelling.\n",
    "final_data = output.select('features', 'crew')\n",
    "final_data.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 70/30 train/test split. The fixed seed makes the split - and every metric\n",
    "# computed from it - repeatable on Restart & Run All for the same input data.\n",
    "train, test = final_data.randomSplit([0.7, 0.3], seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------------------+\n",
      "|summary|              crew|\n",
      "+-------+------------------+\n",
      "|  count|               116|\n",
      "|   mean| 7.842068965517253|\n",
      "| stddev|3.4552569947183707|\n",
      "|    min|              0.59|\n",
      "|    max|              21.0|\n",
      "+-------+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train.describe().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------------------+\n",
      "|summary|              crew|\n",
      "+-------+------------------+\n",
      "|  count|                42|\n",
      "|   mean|7.6619047619047596|\n",
      "| stddev| 3.672975292988452|\n",
      "|    min|              0.59|\n",
      "|    max|              13.6|\n",
      "+-------+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "test.describe().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8818512711220009"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Linear regression of 'crew' on the assembled 'features' vector.\n",
    "lr = LinearRegression(featuresCol='features', labelCol='crew')\n",
    "lr_model = lr.fit(train)\n",
    "\n",
    "# Score the fitted model on the held-out split; the cell displays its RMSE.\n",
    "result = lr_model.evaluate(test)\n",
    "result.rootMeanSquaredError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+\n",
      "|summary|Ship_name|Cruise_line|               Age|           Tonnage|       passengers|           length|            cabins|passenger_density|             crew|\n",
      "+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+\n",
      "|  count|      158|        158|               158|               158|              158|              158|               158|              158|              158|\n",
      "|   mean| Infinity|       null|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873|\n",
      "| stddev|      NaN|       null| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034|\n",
      "|    min|Adventure|    Azamara|                 4|             2.329|             0.66|             2.79|              0.33|             17.7|             0.59|\n",
      "|    max|Zuiderdam|   Windstar|                48|             220.0|             54.0|            11.82|              27.0|            71.43|             21.0|\n",
      "+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.describe().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9409499716842359"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.r2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7776616643794887"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.meanSquaredError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6505290446343073"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result.meanAbsoluteError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------+\n",
      "|corr(crew, passengers)|\n",
      "+----------------------+\n",
      "|    0.9152341306065384|\n",
      "+----------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql.functions import corr\n",
    "df.select(corr('crew', 'passengers')).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+\n",
      "|corr(crew, cabins)|\n",
      "+------------------+\n",
      "|0.9508226063578497|\n",
      "+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.select(corr('crew', 'cabins')).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
